# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from monai.data import CacheDataset, DataLoader, Dataset, DistributedSampler, SmartCacheDataset, load_decathlon_datalist
from monai.transforms import (
    AddChanneld,
    AsChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    RandCropByPosNegLabeld,
    RandSpatialCropSamplesd,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
)
import json
from monai.data.image_reader import NibabelReader


# Loader for GBR_T1T2_matched_image.json.
def load_T1T2matched_datalist(args):
    """Read the matched T1/T2 split file and return its two subject lists.

    Args:
        args: Namespace providing ``split_json`` (path of the JSON file to
            read) and ``rank`` (the top-level JSON keys are printed only
            when it equals 0).

    Returns:
        dict: ``{'training': ..., 'validation': ...}`` holding the values
        stored under those two keys of the JSON file, unchanged.
    """
    with open(args.split_json, 'r') as split_file:
        split = json.load(split_file)

    if args.rank == 0:
        print("load json keys: ", split.keys())

    # Only these two top-level keys are forwarded; anything else is dropped.
    return {name: split[name] for name in ('training', 'validation')}

    
# Loader for T1_T2_folds.json (paired T1/T2 entries).
def load_T1T210K_datalist(args):
    """Build paired T1/T2 path entries for ``fold_0`` of the split file.

    Args:
        args: Namespace providing ``split_json`` (path of the folds JSON),
            ``t1_path`` and ``t2_path`` (root directories that are joined
            with each subject id to form the image paths).

    Returns:
        dict: ``{'training': ..., 'validation': ...}``. Each value is a dict
        mapping a running integer index to ``{"image": [t1_file, t2_file]}``.
    """
    with open(args.split_json, 'r') as split_file:
        fold_zero = json.load(split_file)['fold_0']

    train_ids, val_ids = fold_zero['training'], fold_zero['validation']
    t1_root, t2_root = args.t1_path, args.t2_path

    def pair_up(subject_ids):
        # One entry per subject, keyed by the subject's position in the list.
        return {
            position: {
                "image": [
                    t1_root + '/' + subject + "_20252_2_0/T1_brain_to_MNI.nii.gz",
                    t2_root + '/' + subject + "_20253_2_0/T2_FLAIR/T2_FLAIR_brain_to_MNI.nii.gz",
                ]
            }
            for position, subject in enumerate(subject_ids)
        }

    return {'training': pair_up(train_ids), 'validation': pair_up(val_ids)}
    
# Loader for T1_T2_folds.json in mixed mode (T1 and T2 files become separate single-image entries).
def load_T1T210K_mixed_datalist(args):
    """Build single-image entries mixing T1 and T2 files for ``fold_0``.

    Every subject contributes two consecutive entries: first its T1 file,
    then its T2 FLAIR file.

    Args:
        args: Namespace providing ``split_json`` (path of the folds JSON),
            ``t1_path`` and ``t2_path`` (root directories that are joined
            with each subject id to form the image paths).

    Returns:
        dict: ``{'training': ..., 'validation': ...}``. Each value is a dict
        mapping a running integer index to ``{"image": path}`` where
        ``path`` is a single file name (not a pair).
    """
    with open(args.split_json, 'r') as split_file:
        fold_zero = json.load(split_file)['fold_0']

    train_ids, val_ids = fold_zero['training'], fold_zero['validation']
    t1_root, t2_root = args.t1_path, args.t2_path

    def interleave(subject_ids):
        # Flatten to [t1_a, t2_a, t1_b, t2_b, ...] and number the result.
        paths = []
        for subject in subject_ids:
            t1_file = t1_root + '/' + subject + "_20252_2_0/T1_brain_to_MNI.nii.gz"
            t2_file = t2_root + '/' + subject + "_20253_2_0/T2_FLAIR/T2_FLAIR_brain_to_MNI.nii.gz"
            paths.extend((t1_file, t2_file))
        return {position: {"image": path} for position, path in enumerate(paths)}

    return {'training': interleave(train_ids), 'validation': interleave(val_ids)}
    
def get_T1T2_dataloaders(args, num_workers=4):
    """Create the training and validation loaders for the T1/T2 datasets.

    The datalist is produced by ``load_T1T210K_datalist``,
    ``load_T1T210K_mixed_datalist`` or ``load_T1T2matched_datalist``
    depending on ``args.T1T2_10k`` / ``args.T1T2_10k_mixed``. The same
    transform pipeline (including the random crop) is applied to both the
    training and the validation split.

    Args:
        args: Namespace read for ``T1T2_10k``, ``T1T2_10k_mixed``,
            ``modality``, ``in_channels``, ``rank``, ``a_min``, ``a_max``,
            ``b_min``, ``b_max``, ``roi_x``, ``roi_y``, ``roi_z``,
            ``sw_batch_size``, ``batch_size``, ``cache_dataset``,
            ``smartcache_dataset`` and ``distributed`` (plus whatever the
            selected datalist helper reads). ``train_ds_len`` and
            ``val_ds_len`` are written back onto it.
        num_workers: Worker count handed to both ``DataLoader`` instances
            and to ``CacheDataset``.

    Returns:
        tuple: ``(train_loader, val_loader)``.

    Raises:
        ValueError: If ``args.modality`` is not ``"T1"``, ``"T2"`` or
            ``"T1T2"``.

    Note:
        ``"T1"`` / ``"T2"`` pick element 0 / 1 of each entry's ``"image"``
        value, so they are only meaningful for datalists whose entries hold
        a ``[T1, T2]`` pair; the mixed datalist stores a single path string
        per entry.
    """

    def as_list(split):
        # The 10K helpers return {index: entry} dicts while the matched helper
        # returns lists. Iterating a dict yields its integer keys, so the
        # entries must be pulled out before they can be indexed by "image".
        if isinstance(split, dict):
            return list(split.values())
        return list(split)

    if args.T1T2_10k:
        datalist = load_T1T210K_datalist(args)
    elif args.T1T2_10k_mixed:
        datalist = load_T1T210K_mixed_datalist(args)
    else:
        datalist = load_T1T2matched_datalist(args)

    training_subjects = as_list(datalist['training'])
    validation_subjects = as_list(datalist['validation'])

    if args.modality == "T1T2":
        training_datalist, validation_datalist = training_subjects, validation_subjects
    elif args.modality in ("T1", "T2"):
        # Position of the requested modality inside each [T1, T2] pair.
        channel = 0 if args.modality == "T1" else 1
        training_datalist = [{"image": subject["image"][channel]} for subject in training_subjects]
        validation_datalist = [{"image": subject["image"][channel]} for subject in validation_subjects]
    else:
        raise ValueError("Unsupported modality")

    if args.rank == 0:
        print(f"Training on {len(training_datalist)} {args.modality} images.")
        print(f"Validation on {len(validation_datalist)} {args.modality} images.")

    # With paired input and two input channels the loaded pair is used as the
    # channel dimension directly (presumably LoadImaged stacks the two files;
    # confirm against the MONAI docs). Every other configuration gets an
    # explicit channel axis via AddChanneld.
    paired_channels = args.modality == "T1T2" and args.in_channels == 2
    roi_size = [args.roi_x, args.roi_y, args.roi_z]

    pipeline = [LoadImaged(keys=["image"], reader=NibabelReader)]
    if not paired_channels:
        pipeline.append(AddChanneld(keys=["image"]))
    pipeline += [
        Orientationd(keys=["image"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
        ),
        SpatialPadd(keys=["image"], spatial_size=roi_size),
        CropForegroundd(keys=["image"], source_key="image", k_divisible=roi_size),
        RandSpatialCropSamplesd(
            keys=["image"],
            roi_size=roi_size,
            num_samples=args.sw_batch_size,
            random_center=True,
            random_size=False,
        ),
        ToTensord(keys=["image"]),
    ]
    transforms = Compose(pipeline)

    if args.cache_dataset:
        if args.rank == 0:
            print("Using MONAI Cache Dataset")
        training_dataset = CacheDataset(
            data=training_datalist, transform=transforms, cache_rate=0.5, num_workers=num_workers
        )
    elif args.smartcache_dataset:
        if args.rank == 0:
            print("Using MONAI SmartCache Dataset")
        training_dataset = SmartCacheDataset(
            data=training_datalist,
            transform=transforms,
            replace_rate=1.0,
            cache_num=2 * args.batch_size * args.sw_batch_size,
        )
    else:
        if args.rank == 0:
            print("Using generic dataset")
        training_dataset = Dataset(data=training_datalist, transform=transforms)

    if args.distributed:
        train_sampler = DistributedSampler(dataset=training_dataset, even_divisible=True, shuffle=True)
    else:
        train_sampler = None

    train_loader = DataLoader(
        training_dataset,
        batch_size=args.batch_size,
        num_workers=num_workers,
        sampler=train_sampler,
        drop_last=True,
    )

    val_ds = Dataset(data=validation_datalist, transform=transforms)

    if args.distributed:
        val_sampler = DistributedSampler(dataset=val_ds)
    else:
        val_sampler = None

    # NOTE(review): without a sampler the validation loader shuffles while the
    # training loader above does not request shuffling -- confirm this is
    # intended.
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        num_workers=num_workers,
        shuffle=(val_sampler is None),
        sampler=val_sampler,
        drop_last=True,
    )

    # Expose the dataset sizes to the caller through the shared namespace.
    args.train_ds_len = len(training_dataset)
    args.val_ds_len = len(val_ds)

    return train_loader, val_loader

def get_loader(args):
    """Create training/validation loaders over five decathlon-style datalists.

    Split files are read from ``args.workdir + "/jsons"`` and image paths are
    resolved against per-dataset directories under
    ``args.workdir + "/dataset"``. The "training" sections of all five files
    are concatenated into the training set and the "validation" sections
    into the validation set. Both sets use the same transform pipeline
    (including the random crop).

    Args:
        args: Namespace read for ``workdir``, ``rank``, ``a_min``, ``a_max``,
            ``b_min``, ``b_max``, ``roi_x``, ``roi_y``, ``roi_z``,
            ``sw_batch_size``, ``batch_size``, ``num_workers``,
            ``cache_dataset``, ``smartcache_dataset`` and ``distributed``.
            ``train_ds_len`` and ``val_ds_len`` are written back onto it.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    json_dir = args.workdir + "/jsons"  # the split .json files must live here

    # (split file, data directory, rank-0 log label, keep only the "image" key)
    # The labels reproduce the existing log lines, including the "[rank]"
    # prefix that only the second dataset carries. Note that the fifth source
    # reads from "dataset8".
    sources = (
        ("/dataset_LUNA16_0.json", "/dataset/dataset1", "Dataset 1 LUNA16", True),
        ("/dataset_TCIAcovid19_0.json", "/dataset/dataset2", f"[{args.rank}] " + "Dataset 2 Covid 19", False),
        ("/dataset_HNSCC_0.json", "/dataset/dataset3", "Dataset 3 HNSCC", False),
        ("/dataset_TCIAcolon_UFL.json", "/dataset/dataset4", "Dataset 4 TCIA Colon", False),
        ("/dataset_LIDC_0.json", "/dataset/dataset8", "Dataset 5", False),
    )

    # First pass: training sections, logged one by one on rank 0.
    datalist = []
    for split_file, data_dir, label, image_only in sources:
        items = load_decathlon_datalist(
            json_dir + split_file, False, "training", base_dir=args.workdir + data_dir
        )
        if image_only:
            # Drop every key except "image" from these entries.
            items = [{"image": item["image"]} for item in items]
        if args.rank == 0:
            print("{}: number of data: {}".format(label, len(items)))
        datalist.extend(items)

    # Second pass: validation sections, used as loaded.
    val_files = []
    for split_file, data_dir, _label, _image_only in sources:
        val_files.extend(
            load_decathlon_datalist(
                json_dir + split_file, False, "validation", base_dir=args.workdir + data_dir
            )
        )

    if args.rank == 0:
        print(f"[{args.rank}] " + "Dataset all training: number of data: {}".format(len(datalist)))
        print(f"[{args.rank}] " + "Dataset all validation: number of data: {}".format(len(val_files)))

    roi_size = [args.roi_x, args.roi_y, args.roi_z]

    def build_transforms():
        # Training and validation use the same steps; each gets its own
        # Compose instance.
        return Compose(
            [
                LoadImaged(keys=["image"]),
                AddChanneld(keys=["image"]),
                Orientationd(keys=["image"], axcodes="RAS"),
                ScaleIntensityRanged(
                    keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
                ),
                SpatialPadd(keys=["image"], spatial_size=roi_size),
                CropForegroundd(keys=["image"], source_key="image", k_divisible=roi_size),
                RandSpatialCropSamplesd(
                    keys=["image"],
                    roi_size=roi_size,
                    num_samples=args.sw_batch_size,
                    random_center=True,
                    random_size=False,
                ),
                ToTensord(keys=["image"]),
            ]
        )

    train_transforms = build_transforms()
    val_transforms = build_transforms()

    if args.cache_dataset:
        if args.rank == 0:
            print(f"[{args.rank}] " + "Using MONAI Cache Dataset")
        train_ds = CacheDataset(
            data=datalist, transform=train_transforms, cache_rate=0.5, num_workers=args.num_workers
        )
    elif args.smartcache_dataset:
        if args.rank == 0:
            print(f"[{args.rank}] " + "Using MONAI SmartCache Dataset")
        train_ds = SmartCacheDataset(
            data=datalist,
            transform=train_transforms,
            replace_rate=1.0,
            cache_num=2 * args.batch_size * args.sw_batch_size,
        )
    else:
        if args.rank == 0:
            print(f"[{args.rank}] " + "Using generic dataset")
        train_ds = Dataset(data=datalist, transform=train_transforms)
    if args.rank == 0:
        print(f"train dataset = {len(train_ds)}")

    if args.distributed:
        train_sampler = DistributedSampler(dataset=train_ds, even_divisible=True, shuffle=True)
    else:
        train_sampler = None

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        sampler=train_sampler,
        drop_last=True,
    )

    val_ds = Dataset(data=val_files, transform=val_transforms)

    # Expose the dataset sizes to the caller through the shared namespace.
    args.train_ds_len = len(train_ds)
    args.val_ds_len = len(val_ds)

    if args.distributed:
        val_sampler = DistributedSampler(dataset=val_ds)
    else:
        val_sampler = None

    # NOTE(review): without a sampler the validation loader shuffles while the
    # training loader above does not request shuffling -- confirm this is
    # intended.
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=(val_sampler is None),
        sampler=val_sampler,
        drop_last=True,
    )
    return train_loader, val_loader
